from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

from pyrallis import field


@dataclass
class ContextEncoderTrainConfig:
    """Hyper-parameters and run settings for context-encoder training.

    The fields are plain dataclass attributes grouped by the section
    comments below. List-valued defaults are declared with
    ``pyrallis.field(..., is_mutable=True)`` rather than as bare list
    literals, because a dataclass rejects a bare mutable default.

    Annotations matter here beyond static checking: pyrallis decodes
    command-line / config-file overrides according to the declared type, so
    fields that may be ``None`` are declared ``Optional[...]`` and layer-size
    lists are declared ``List[int]``.
    """

    # wandb params
    project: str = "OSRL-baselines"
    group: Optional[str] = None
    name: Optional[str] = None
    prefix: Optional[str] = "ContextEncoder"
    suffix: Optional[str] = ""
    logdir: Optional[str] = "logs"
    verbose: bool = True
    # dataset params
    # None defaults below presumably mean "option disabled" -- verify
    # against the code that consumes this config.
    outliers_percent: Optional[float] = None
    noise_scale: Optional[float] = None
    inpaint_ranges: Optional[Tuple[Tuple[float, float], ...]] = None
    epsilon: Optional[float] = None
    density: float = 1.0
    # model params
    state_encoding_dim: int = 32
    action_encoding_dim: int = 2
    context_encoding_dim: int = 16
    # hidden layer sizes are integer counts, hence List[int]
    context_encoder_hidden_sizes: List[int] = field(
        default=[256, 256, 256], is_mutable=True
    )
    feature_hidden_sizes: List[int] = field(default=[256, 128], is_mutable=True)
    decoder_hidden_sizes: List[int] = field(default=[128, 128], is_mutable=True)
    simple_gate: bool = False
    simple_mlp: bool = False
    state_loss_weight: float = 1.0
    reward_loss_weight: float = 0.0
    cost_loss_weight: float = 2.0
    context_size: int = 500
    vis_nums: int = 1000
    # training params
    # task: str = "OfflineCarCircle-v0"
    learning_rate: float = 1e-4
    min_learning_rate: float = 3e-5
    betas: Tuple[float, float] = (0.9, 0.999)
    decay_step: int = 2000
    decay_rate: float = 0.9
    clip_grad: Optional[float] = 0.25
    batch_size: int = 2048
    meta_batch_per_task: int = 4
    update_steps: int = 100_000
    eval_every: int = 1000
    lr_warmup_steps: int = 500
    reward_scale: float = 0.1
    cost_scale: float = 1
    num_workers: int = 8
    # general params
    seed: int = 0
    # NOTE(review): the default pins one specific GPU index; override it on
    # machines that do not expose that device.
    device: str = "cuda:5"
    threads: int = 6
    # augmentation param
    deg: int = 4
    pf_sample: bool = False
    beta: float = 1.0
    augment_percent: float = 0.2
    # maximum absolute value of reward for the augmented trajs
    max_reward: float = 600.0
    # minimum reward above the PF curve
    min_reward: float = 1.0
    # the max decrease of ret between the associated traj
    # w.r.t the nearest pf traj
    max_rew_decrease: float = 100.0
    # pf only mode param
    pf_only: bool = False
    rmin: float = 300
    cost_bins: int = 60
    npb: int = 5
    cost_sample: bool = True
    linear: bool = True  # linear or inverse
    start_sampling: bool = False
    prob: float = 0.2
    stochastic: bool = True
    init_temperature: float = 0.1
    no_entropy: bool = False
    # random augmentation
    random_aug: float = 0
    aug_rmin: float = 400
    aug_rmax: float = 500
    aug_cmin: float = -2
    aug_cmax: float = 25
    cgap: float = 5
    rstd: float = 1
    cstd: float = 0.2
    # Pretrained state auto-encoder paths: 26 entries, each path listed twice
    # in a row. Presumably one entry per training task -- confirm the task
    # ordering against the training script.
    state_encoder_paths: List[str] = field(default=[
        "logs/PointButtonGymnasium-cost-10/sa_encoder-718f/sa_encoder-718f_state_AE",
        "logs/PointButtonGymnasium-cost-10/sa_encoder-718f/sa_encoder-718f_state_AE",
        "logs/PointCircleGymnasium-cost-10/sa_encoder-e510/sa_encoder-e510_state_AE",
        "logs/PointCircleGymnasium-cost-10/sa_encoder-e510/sa_encoder-e510_state_AE",
        "logs/PointGoalGymnasium-cost-10/sa_encoder-0739/sa_encoder-0739_state_AE",
        "logs/PointGoalGymnasium-cost-10/sa_encoder-0739/sa_encoder-0739_state_AE",
        "logs/PointPushGymnasium-cost-10/sa_encoder-0710/sa_encoder-0710_state_AE",
        "logs/PointPushGymnasium-cost-10/sa_encoder-0710/sa_encoder-0710_state_AE",
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_state_AE",
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_state_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_state_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_state_AE",
        "logs/CarButtonGymnasium-cost-10/sa_encoder-0b6c/sa_encoder-0b6c_state_AE",
        "logs/CarButtonGymnasium-cost-10/sa_encoder-0b6c/sa_encoder-0b6c_state_AE",
        "logs/CarCircleGymnasium-cost-10/sa_encoder-8727/sa_encoder-8727_state_AE",
        "logs/CarCircleGymnasium-cost-10/sa_encoder-8727/sa_encoder-8727_state_AE",
        "logs/CarGoalGymnasium-cost-10/sa_encoder-aa9d/sa_encoder-aa9d_state_AE",
        "logs/CarGoalGymnasium-cost-10/sa_encoder-aa9d/sa_encoder-aa9d_state_AE",
        "logs/CarPushGymnasium-cost-10/sa_encoder-cda6/sa_encoder-cda6_state_AE",
        "logs/CarPushGymnasium-cost-10/sa_encoder-cda6/sa_encoder-cda6_state_AE",
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_state_AE",
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_state_AE",
        "logs/SwimmerVelocityGymnasium-cost-10/sa_encoder-8a6f/sa_encoder-8a6f_state_AE",
        "logs/SwimmerVelocityGymnasium-cost-10/sa_encoder-8a6f/sa_encoder-8a6f_state_AE",
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_state_AE",
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_state_AE"
    ], is_mutable=True)
    # Action auto-encoder paths: also 26 entries, so the list lines up
    # index-for-index with state_encoder_paths. None marks entries with no
    # action encoder path (presumably the raw action is used there -- TODO
    # confirm), which is why the element type is Optional[str].
    action_encoder_paths: List[Optional[str]] = field(default=[
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_action_AE",
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_action_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_action_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_action_AE",
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_action_AE",
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_action_AE",
        None,
        None,
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_action_AE",
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_action_AE"
    ], is_mutable=True)
